# ------------------------------------------------------------------------------
# Simulate changes in testing: raise every shift bin to the shift_bin == 4 test rate
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 10_simulating_testing.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(testit) # assert()
library(glue) # glue strings
library(ggplot2)
library(ggthemes) # colorblind
library(broom) # tidy

# Directory for this step's intermediate files: shift_bins_df.rds is read from
# it and simulating_tests.csv is written to it further down.
temp <- here::here("code", "05_natural_experiment", "temp")
# NOTE(review): not referenced anywhere else in the script as shown -- confirm
# whether it is still needed.
overnight_lab <- ""

# Project-local modules: `a` is loaded from lib/aesthetics.R and `u` from
# lib/util.R (`u` supplies safe_left_join(), used when building the cohort).
a <- modules::use(here::here("lib", "aesthetics.R"))
u <- modules::use(here::here("lib", "util.R"))
# Hands the font name and the bundled .ttf path to the aesthetics module;
# presumably registers the font for plotting -- TODO confirm in aesthetics.R.
a$get_font("Optima", here("lib", "optima.ttf"))

# Load Data --------------------------------------------------------------------
message("Loading data...")

# lib/filepaths.yml holds the file locations; everything read in this section
# sits under its `analysis` entry.
paths <- read_yaml(here::here("lib", "filepaths.yml"))

life_exp <- readRDS(paths$analysis$life_expectancy)

# Only the columns named in each select() are carried forward from the
# long-term outcome and DALY/cost tables.
longterm_outcomes <- select(
  readRDS(paths$analysis$longterm_outcomes),
  ed_enc_id, death_365_day, macetrop_31_to_365_pos
)
daly_cost <- select(
  readRDS(paths$analysis$daly_cost),
  ed_enc_id, cost, daly_100, daly_200, daly_250, daly_300, daly_500
)

# NOTE(review): not referenced again in the script as shown; the same file is
# read a second time into costeff_risk_cutoff further down.
risk_cutoff <- readRDS(paths$analysis$costeff_risk_cutoff)

# Build the analysis cohort:
#   1. read the shift-bin data frame and drop rows flagged `exclude`;
#   2. attach life expectancy, long-term outcomes and DALY/cost columns via
#      u$safe_left_join() (join keys are whatever that helper infers --
#      presumably ed_enc_id; verify in lib/util.R);
#   3. add short aliases for the long model/bin column names;
#   4. keep only rows whose split is "train" or "test".
cohort <- file.path(temp, "shift_bins_df.rds") %>%
  readRDS() %>%
  filter(!exclude) %>%
  u$safe_left_join(life_exp) %>%
  u$safe_left_join(longterm_outcomes) %>%
  u$safe_left_join(daly_cost) %>%
  mutate(
    yhat = p__ensemble__stent_or_cabg_010_day__tested,
    yhat_bin = factor(tile_stent_or_cabg_005_tested),
    shift_bin = bin_re_shift_12_full_test_010_day_inclyhat_linear,
    # TRUE only when the 31-to-365 flag is set and the 0-30 flag is not.
    macetrop_31_to_365_pos_excl = macetrop_31_to_365_pos & !macetrop_030_pos
  ) %>%
  filter(split %in% c("train", "test"))

# Group by Bin
# Work on the rows whose yhat_bin equals 5 (reported as "Q5" below).
subcohort <- cohort %>%
  filter(yhat_bin == 5)

# Test rate among shift_bin == 4 rows; every bin's rate is compared against
# this value.
top_test_rate <- subcohort %>%
  filter(shift_bin == 4) %>%
  pull(test_010_day) %>%
  mean()

# One row per shift_bin: observed rates and costs, plus the extra tests needed
# to bring the bin up to top_test_rate.
marginal_tests <- subcohort %>%
  group_by(shift_bin) %>%
  summarise(
    n = n(),
    test_rate = mean(test_010_day),
    total_costs = sum(cost),
    costs_per_test = total_costs / sum(test_010_day),
    death_365_rate = mean(death_365_day),
    macetrop_31_365_excl_rate = mean(macetrop_31_to_365_pos_excl),
    marginal_test_rate = top_test_rate - test_rate,
    marginal_test_n = marginal_test_rate * n,
    # yhat is the alias of p__ensemble__stent_or_cabg_010_day__tested.
    mean_risk = mean(yhat),
    exp_ly_remaining = mean(ly_remaining)
  ) %>%
  ungroup() %>%
  # Death outcome. The numeric coefficients and the 150000 multiplier are
  # hard-coded; presumably regression estimates and a dollar value per
  # life-year -- TODO confirm their source.
  mutate(
    death_pp_delta =
      0.0379 * marginal_test_rate - 0.4266 * mean_risk * marginal_test_rate,
    n_death_delta = n * death_pp_delta,
    death_ly_value = n_death_delta * exp_ly_remaining,
    death_dollar_value = death_ly_value * 150000
  ) %>%
  # MACE outcome, same structure; the extra 0.125 factor is presumably a
  # weight relative to a full life-year -- TODO confirm.
  mutate(
    mace_pp_delta =
      0.0388 * marginal_test_rate - 0.2512 * mean_risk * marginal_test_rate,
    n_mace_delta = n * mace_pp_delta,
    mace_ly_value = n_mace_delta * exp_ly_remaining * 0.125,
    mace_dollar_value = mace_ly_value * 150000
  )

# Show the headline columns of the Q5 table and save the full table to temp.
message("Q5 sample results")
marginal_tests %>%
  select(shift_bin, n, marginal_test_rate, death_ly_value, mace_ly_value) %>%
  print()
write_csv(marginal_tests, file.path(temp, "simulating_tests.csv"))

# Added tests summed over shift bins, compared with the tested / untested
# counts of the whole cohort (not only the Q5 subcohort).
total_marg_tests <- sum(marginal_tests[["marginal_test_n"]])
num_untested <- sum(!cohort[["test_010_day"]])
num_tested <- sum(cohort[["test_010_day"]])
message("Num added tests: ", total_marg_tests)
message("Pct of untested who we now test: ", total_marg_tests / num_untested)
message(
  "Untested we now test as pctge of all tested: ",
  total_marg_tests / num_tested
)
message("")

# Q5 test rate before and after the simulation.
# old_trate: observed test rate over the whole Q5 subcohort.
old_trate <- mean(subcohort$test_010_day)
# new_trate: the rate every bin is raised to, i.e. the shift_bin == 4 rate.
# Take it from top_test_rate instead of row 4 of marginal_tests: the row
# position only coincides with shift_bin == 4 when the bins are exactly 1..4
# with no NA or extra group, whereas top_test_rate is defined by the bin value.
new_trate <- top_test_rate
pct_pt_change <- new_trate - old_trate
pct_change <- pct_pt_change / old_trate

message("Original Q5 test rate: ", old_trate)
message("New Q5 test rate: ", new_trate)
message("Pctage pt testing increase in Q5: ", pct_pt_change)
message("Pct testing increase in Q5: ", pct_change)

# Simulations for cohort above cost-effectiveness threshold --------------------
message("------------------")
message("Getting simulations for subset of cohort above costeff threshold...")
costeff_risk_cutoff <- readRDS(paths$analysis$costeff_risk_cutoff)

# Test-split rows whose risk column exceeds the cost-effectiveness cutoff.
subcohort_costeff <- filter(
  cohort,
  split == "test" &
    p__ensemble__stent_or_cabg_010_day__tested > costeff_risk_cutoff
)

# Test rate among shift_bin == 4 rows of this subcohort; the target rate for
# the simulation below.
top_test_rate_costeff <- subcohort_costeff %>%
  filter(shift_bin == 4) %>%
  .[["test_010_day"]] %>%
  mean()

# One row per shift_bin: observed rates and costs, plus the extra tests needed
# to bring the bin up to top_test_rate_costeff.
marginal_tests_costeff <- subcohort_costeff %>%
  group_by(shift_bin) %>%
  summarize(
    n = n(),
    test_rate = mean(test_010_day),
    total_costs = sum(cost),
    costs_per_test = total_costs / sum(test_010_day),
    death_365_rate = mean(death_365_day),
    macetrop_31_365_excl_rate = mean(macetrop_31_to_365_pos_excl),
    # Fix: compare against this subcohort's own top rate. The original used
    # top_test_rate (computed on the Q5 subcohort), so the gap was measured
    # against the wrong population and top_test_rate_costeff was never used.
    marginal_test_rate = top_test_rate_costeff - test_rate,
    marginal_test_n = marginal_test_rate * n,
    mean_risk = mean(p__ensemble__stent_or_cabg_010_day__tested),
    exp_ly_remaining = mean(ly_remaining)
  ) %>%
  ungroup() %>%
  # The numeric coefficients, the 150000 multiplier and the 0.125 MACE factor
  # are hard-coded and match the Q5 block; presumably regression estimates, a
  # dollar value per life-year and a MACE weight -- TODO confirm their source.
  mutate(
    death_pp_delta =
      0.0379 * marginal_test_rate - 0.4266 * mean_risk * marginal_test_rate,
    n_death_delta = n * death_pp_delta,
    death_ly_value = n_death_delta * exp_ly_remaining,
    death_dollar_value = death_ly_value * 150000,
    mace_pp_delta =
      0.0388 * marginal_test_rate - 0.2512 * mean_risk * marginal_test_rate,
    n_mace_delta = n * mace_pp_delta,
    mace_ly_value = n_mace_delta * exp_ly_remaining * 0.125,
    mace_dollar_value = mace_ly_value * 150000
  )

# Headline columns of the cost-effective table (printed before its banner
# message, as in the original ordering).
marginal_tests_costeff %>%
  select(shift_bin, n, marginal_test_rate, death_ly_value, mace_ly_value) %>%
  print()

# Added tests summed over shift bins, compared with the tested / untested
# counts of the whole cohort.
total_marg_tests_costeff <- sum(marginal_tests_costeff[["marginal_test_n"]])
num_untested <- sum(!cohort[["test_010_day"]])
num_tested <- sum(cohort[["test_010_day"]])
message("Cost-effective sample results")
message("Num added tests: ", total_marg_tests_costeff)
message(
  "Pct of untested who we now test: ",
  total_marg_tests_costeff / num_untested
)
message(
  "Untested we now test as pctge of all tested: ",
  total_marg_tests_costeff / num_tested
)
message("")

# Cost-effective subcohort test rate before and after the simulation.
# old_trate: observed test rate over the whole cost-effective subcohort.
old_trate <- mean(subcohort_costeff$test_010_day)
# new_trate: the shift_bin == 4 rate within this subcohort. Take it from
# top_test_rate_costeff instead of row 4 of marginal_tests_costeff: the row
# position only coincides with shift_bin == 4 when the bins are exactly 1..4
# with no NA or extra group.
new_trate <- top_test_rate_costeff
pct_pt_change <- new_trate - old_trate
pct_change <- pct_pt_change / old_trate

message("Original cost-eff test rate: ", old_trate)
message("New cost-eff test rate: ", new_trate)
message("Pctage pt testing increase in cost-eff: ", pct_pt_change)
message("Pct testing increase in cost-eff: ", pct_change)

message("Done.")
